{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "networks_seq2seq_nmt.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU",
    "pycharm": {
      "stem_cell": {
        "cell_type": "raw",
        "source": [],
        "metadata": {
          "collapsed": false
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "f9ySOjrcc0Yp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bl9GdT7h0Hxk",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WhwgQAn50EZp",
        "colab_type": "text"
      },
      "source": [
        "# TensorFlow Addons Networks : Sequence-to-Sequence NMT with Attention Mechanism\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/addons/blob/master/docs/tutorials/networks_seq2seq_nmt.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/addons/blob/master/docs/tutorials/networks_seq2seq_nmt.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "      <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/addons/docs/tutorials/networks_seq2seq_nmt.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ip0n8178Fuwm",
        "colab_type": "text"
      },
      "source": [
        "## Overview\n",
        "This notebook gives a brief introduction into the ***Sequence to Sequence Model Architecture***\n",
        "In this noteboook we broadly cover four essential topics necessary for Neural Machine Translation:\n",
        "\n",
        "\n",
        "* **Data cleaning**\n",
        "* **Data preparation**\n",
        "* **Neural Translation Model with Attention**\n",
        "* **Final Translation**\n",
        "\n",
        "The basic idea behind such a model though, is only the encoder-decoder architecture. These networks are usually used for a variety of tasks like text-summerization, Machine translation, Image Captioning, etc. This tutorial provideas a hands-on understanding of the concept, explaining the technical jargons wherever necessary. We focus on the task of Neural Machine Translation (NMT) which was the very first testbed for seq2seq models.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YNiadLKNLleD",
        "colab_type": "text"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "82GcQTsGf414",
        "colab_type": "text"
      },
      "source": [
        "## Additional Resources:\n",
        "\n",
        "These are a lst of resurces you must install in order to allow you to run this notebook:\n",
        "\n",
        "\n",
        "1. [German-English Dataset](http://www.manythings.org/anki/deu-eng.zip)\n",
        "\n",
        "\n",
        "The dataset should be downloaded, in order to compile this notebook, the embeddings can be used, as they are pretrained. Though, we carry out our own training here !!\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5OIlpST_6ga-",
        "colab_type": "code",
        "outputId": "bcfc583a-fa40-401f-a96e-d95f5cb954c8",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 153
        }
      },
      "source": [
        "#download data\n",
        "print(\"Downloading Dataset:\")\n",
        "!wget --quiet http://www.manythings.org/anki/deu-eng.zip\n",
        "!unzip deu-eng.zip"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading Dataset:\n",
            "Archive:  deu-eng.zip\n",
            "  inflating: deu.txt                 \n",
            "  inflating: _about.txt              \n",
            "Downloading Dataset:\n",
            "Archive:  deu-eng.zip\n",
            "replace deu.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: "
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "co6-YpBwL-4d",
        "colab_type": "code",
        "outputId": "c9f02df6-176e-4b83-e5af-fe8c9f688dff",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "import csv\n",
        "import string\n",
        "import re\n",
        "from pickle import dump\n",
        "from unicodedata import normalize\n",
        "from numpy import array\n",
        "import itertools\n",
        "from pickle import load\n",
        "from tensorflow.keras.utils import to_categorical\n",
        "from keras.utils.vis_utils import plot_model\n",
        "from tensorflow.keras.models import Sequential\n",
        "from tensorflow.keras.layers import LSTM\n",
        "from tensorflow.keras.layers import Dense\n",
        "from tensorflow.keras.layers import Embedding\n",
        "from pickle import load\n",
        "from numpy import array\n",
        "from numpy import argmax\n",
        "import tensorflow as tf\n",
        "from keras.models import load_model\n",
        "from nltk.translate.bleu_score import corpus_bleu\n",
        "from sklearn.model_selection import train_test_split\n",
        "import tensorflow_addons as tfa"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Using TensorFlow backend.\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q7gjUT_9XSoj",
        "colab_type": "text"
      },
      "source": [
        "## Data Cleaning\n",
        "\n",
        "Our data set is a German-English translation dataset. It contains 152,820 pairs of English to German phases, one pair per line with a tab separating the language. These dataset though organized needs cleaning before we can work on it. This will enable us to remove unnecessary bumps that may come in during the training."
      ]
    },
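    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a rough illustration of the line format and the cleaning steps described above, the sketch below processes a single tab-separated line using only the standard library. The helper name `clean_line` and the sample sentence are assumptions made for this example; the notebook's own cleaning function follows in the next cell and differs in small details such as spacing around the markers.\n",
        "\n",
        "```python\n",
        "import string\n",
        "from unicodedata import normalize\n",
        "\n",
        "# translation table that deletes ASCII punctuation\n",
        "PUNCT_TABLE = str.maketrans('', '', string.punctuation)\n",
        "\n",
        "def clean_line(line):\n",
        "    # decompose accented characters, then drop the non-ASCII marks\n",
        "    line = normalize('NFD', line).encode('ascii', 'ignore').decode('utf-8')\n",
        "    # lowercase each whitespace-separated token and strip punctuation\n",
        "    words = [w.lower().translate(PUNCT_TABLE) for w in line.split()]\n",
        "    # keep purely alphabetic tokens only\n",
        "    words = [w for w in words if w.isalpha()]\n",
        "    return ' '.join(['<start>'] + words + ['<end>'])\n",
        "\n",
        "# hypothetical sample line: English and German separated by a tab\n",
        "sample = 'Go on.\\tMach weiter.'\n",
        "english, german = sample.split('\\t')[:2]\n",
        "print(clean_line(english))\n",
        "print(clean_line(german))\n",
        "```\n",
        "\n",
        "For this sample the two printed lines are `<start> go on <end>` and `<start> mach weiter <end>`. The `<start>` and `<end>` markers tell the decoder where a sentence begins and ends."
      ]
    },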
    {
      "cell_type": "code",
      "metadata": {
        "id": "6ZIu-TNqKFsd",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        " # load doc into memory\n",
        "def load_documnet(filename):\n",
        "# open the file as read only\n",
        "  file = open(filename, mode='rt', encoding='utf-8')\n",
        "  # read all text\n",
        "  text = file.read()\n",
        "  # close the file\n",
        "  file.close()\n",
        "  return text\n",
        "\n",
        "# split a loaded document into sentences\n",
        "def doc_sep_pair(doc):\n",
        "  lines = doc.strip().split('\\n')\n",
        "  pairs = [line.split('\\t') for line in  lines]\n",
        "  return pairs\n",
        "\n",
        "# clean a list of lines\n",
        "def clean_sentences(lines):\n",
        "  cleaned = list()\n",
        "  re_print = re.compile('[^%s]' % re.escape(string.printable))\n",
        "  # prepare translation table \n",
        "  table = str.maketrans('', '', string.punctuation)\n",
        "  for pair in lines:\n",
        "    clean_pair = list()\n",
        "    for line in pair:\n",
        "      # normalizing unicode characters\n",
        "      line = normalize('NFD', line).encode('ascii', 'ignore')\n",
        "      line = line.decode('UTF-8')\n",
        "      # tokenize on white space\n",
        "      line = line.split()\n",
        "      # convert to lowercase\n",
        "      line = [word.lower() for word in line]\n",
        "      # removing punctuation\n",
        "      line = [word.translate(table) for word in line]\n",
        "      # removing non-printable chars form each token\n",
        "      line = [re_print.sub('', w) for w in line]\n",
        "      # removing tokens with numbers\n",
        "      line = [word for word in line if word.isalpha()]\n",
        "\n",
        "      line.insert(0,'<start> ')\n",
        "      line.append(' <end>')\n",
        "      # store as string\n",
        "      clean_pair.append(' '.join(line))\n",
        "    cleaned.append(clean_pair)\n",
        "  return array(cleaned)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eXpft1qQknO8",
        "colab_type": "text"
      },
      "source": [
        "## Saving the Cleaned Dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GMxdlVU1X8yI",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# load dataset\n",
        "filename = 'deu.txt' #change filename if necessary\n",
        "doc = load_documnet(filename)\n",
        "\n",
        "#clean sentences and save clean data\n",
        "pairs = doc_sep_pair(doc)\n",
        "clean_sentences = clean_sentences(pairs)\n",
        "raw_data = clean_sentences\n",
        "data = raw_data[:10000, :2] \n",
        "import numpy as np\n",
        "raw_data_en = list()\n",
        "raw_data_ge = list()\n",
        "for data1 in data:\n",
        "  raw_data_en.append(data1[0]),raw_data_ge.append(data1[1])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Cfb66QxWYr6A",
        "colab_type": "text"
      },
      "source": [
        "## Tokenization"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3oq60MBPSanQ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "en_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')\n",
        "en_tokenizer.fit_on_texts(raw_data_en)\n",
        "\n",
        "data_en = en_tokenizer.texts_to_sequences(raw_data_en)\n",
        "data_en = tf.keras.preprocessing.sequence.pad_sequences(data_en,padding='post')\n",
        "\n",
        "ge_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')\n",
        "ge_tokenizer.fit_on_texts(raw_data_ge)\n",
        "\n",
        "data_ge = ge_tokenizer.texts_to_sequences(raw_data_ge)\n",
        "data_ge = tf.keras.preprocessing.sequence.pad_sequences(data_ge,padding='post')"
      ],
      "execution_count": 0,
      "outputs": []
    },
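    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To make the tokenization step above more concrete, here is a minimal pure-Python sketch of what it does: build a word index ordered by frequency (index 0 stays free for padding), map sentences to integer sequences, and pad them at the end to a common length. The helpers `fit_word_index` and `texts_to_padded` and the sample sentences are hypothetical and only approximate the behavior of the Keras `Tokenizer` and `pad_sequences` with `padding='post'`.\n",
        "\n",
        "```python\n",
        "from collections import Counter\n",
        "\n",
        "def fit_word_index(texts):\n",
        "    # count whitespace-separated tokens across all sentences\n",
        "    counts = Counter(word for text in texts for word in text.split())\n",
        "    # most frequent word gets index 1; index 0 is left free for padding\n",
        "    ordered = sorted(counts, key=lambda word: -counts[word])\n",
        "    return {word: i + 1 for i, word in enumerate(ordered)}\n",
        "\n",
        "def texts_to_padded(texts, word_index):\n",
        "    sequences = [[word_index[w] for w in text.split()] for text in texts]\n",
        "    longest = max(len(seq) for seq in sequences)\n",
        "    # 'post' padding appends zeros after the real tokens\n",
        "    return [seq + [0] * (longest - len(seq)) for seq in sequences]\n",
        "\n",
        "# hypothetical cleaned sentences\n",
        "texts = ['<start> go on <end>', '<start> go <end>']\n",
        "word_index = fit_word_index(texts)\n",
        "padded = texts_to_padded(texts, word_index)\n",
        "print(word_index)\n",
        "print(padded)\n",
        "```\n",
        "\n",
        "Because index 0 never refers to a real word, the loss function later in the notebook can use it to ignore padded positions."
      ]
    },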
    {
      "cell_type": "code",
      "metadata": {
        "id": "XH5oSRNeSc1s",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def max_len(tensor):\n",
        "    #print( np.argmax([len(t) for t in tensor]))\n",
        "    return max( len(t) for t in tensor)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KdM37lNBGXAj",
        "colab_type": "text"
      },
      "source": [
        "## Model Parameters"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EfiBUJM2Et6C",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "X_train,  X_test, Y_train, Y_test = train_test_split(data_en,data_ge,test_size=0.2)\n",
        "BATCH_SIZE = 64\n",
        "BUFFER_SIZE = len(X_train)\n",
        "steps_per_epoch = BUFFER_SIZE//BATCH_SIZE\n",
        "embedding_dims = 256\n",
        "rnn_units = 1024\n",
        "dense_units = 1024\n",
        "Dtype = tf.float32   #used to initialize DecoderCell Zero state"
      ],
      "execution_count": 0,
      "outputs": []
    },
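    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a quick check of the arithmetic behind `steps_per_epoch`, the sketch below reproduces the split sizes with plain integers. It assumes the 10,000 sentence pairs kept earlier in the notebook and a 20% test split; the variable names are chosen for this example only.\n",
        "\n",
        "```python\n",
        "num_pairs = 10000   # sentence pairs kept from the cleaned data\n",
        "batch_size = 64\n",
        "\n",
        "# test_size=0.2 holds out one fifth of the pairs\n",
        "num_test = num_pairs // 5\n",
        "num_train = num_pairs - num_test\n",
        "\n",
        "# number of full batches in one pass over the training data\n",
        "steps_per_epoch = num_train // batch_size\n",
        "print(num_train, steps_per_epoch)\n",
        "```\n",
        "\n",
        "This gives 8000 training pairs and 125 steps per epoch, which matches the last batch number printed for each epoch in the training log."
      ]
    },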
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ff_jQHLhGqJU",
        "colab_type": "text"
      },
      "source": [
        "## Dataset Prepration"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "b__1hPHVFALO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "Tx = max_len(data_en)\n",
        "Ty = max_len(data_ge)  \n",
        "\n",
        "input_vocab_size = len(en_tokenizer.word_index)+1  \n",
        "output_vocab_size = len(ge_tokenizer.word_index)+ 1\n",
        "dataset = tf.data.Dataset.from_tensor_slices((X_train, Y_train)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)\n",
        "example_X, example_Y = next(iter(dataset))\n",
        "#print(example_X.shape) \n",
        "#print(example_Y.shape) "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UQRgJcYgapqE",
        "colab_type": "text"
      },
      "source": [
        "## Defining NMT Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sGdakRtjaokF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#ENCODER\n",
        "class EncoderNetwork(tf.keras.Model):\n",
        "    def __init__(self,input_vocab_size,embedding_dims, rnn_units ):\n",
        "        super().__init__()\n",
        "        self.encoder_embedding = tf.keras.layers.Embedding(input_dim=input_vocab_size,\n",
        "                                                           output_dim=embedding_dims)\n",
        "        self.encoder_rnnlayer = tf.keras.layers.LSTM(rnn_units,return_sequences=True, \n",
        "                                                     return_state=True )\n",
        "    \n",
        "#DECODER\n",
        "class DecoderNetwork(tf.keras.Model):\n",
        "    def __init__(self,output_vocab_size, embedding_dims, rnn_units):\n",
        "        super().__init__()\n",
        "        self.decoder_embedding = tf.keras.layers.Embedding(input_dim=output_vocab_size,\n",
        "                                                           output_dim=embedding_dims) \n",
        "        self.dense_layer = tf.keras.layers.Dense(output_vocab_size)\n",
        "        self.decoder_rnncell = tf.keras.layers.LSTMCell(rnn_units)\n",
        "        # Sampler\n",
        "        self.sampler = tfa.seq2seq.sampler.TrainingSampler()\n",
        "        # Create attention mechanism with memory = None\n",
        "        self.attention_mechanism = self.build_attention_mechanism(dense_units,None,BATCH_SIZE*[Tx])\n",
        "        self.rnn_cell =  self.build_rnn_cell(BATCH_SIZE)\n",
        "        self.decoder = tfa.seq2seq.BasicDecoder(self.rnn_cell, sampler= self.sampler,\n",
        "                                                output_layer=self.dense_layer)\n",
        "\n",
        "    def build_attention_mechanism(self, units,memory, memory_sequence_length):\n",
        "        return tfa.seq2seq.LuongAttention(units, memory = memory, \n",
        "                                          memory_sequence_length=memory_sequence_length)\n",
        "        #return tfa.seq2seq.BahdanauAttention(units, memory = memory, memory_sequence_length=memory_sequence_length)\n",
        "\n",
        "    # wrap decodernn cell  \n",
        "    def build_rnn_cell(self, batch_size ):\n",
        "        rnn_cell = tfa.seq2seq.AttentionWrapper(self.decoder_rnncell, self.attention_mechanism,\n",
        "                                                attention_layer_size=dense_units)\n",
        "        return rnn_cell\n",
        "    \n",
        "    def build_decoder_initial_state(self, batch_size, encoder_state,Dtype):\n",
        "        decoder_initial_state = self.rnn_cell.get_initial_state(batch_size = batch_size, \n",
        "                                                                dtype = Dtype)\n",
        "        decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state) \n",
        "        return decoder_initial_state\n",
        "\n",
        "encoderNetwork = EncoderNetwork(input_vocab_size,embedding_dims, rnn_units)\n",
        "decoderNetwork = DecoderNetwork(output_vocab_size,embedding_dims, rnn_units)\n",
        "optimizer = tf.keras.optimizers.Adam()\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
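    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The decoder above relies on Luong-style (multiplicative) attention. As a simplified sketch of the idea, and not of the TensorFlow Addons implementation, the code below scores each encoder step with a dot product against the decoder state, turns the scores into weights with a softmax, and forms a context vector as the weighted sum. The function names and toy vectors are assumptions for illustration, and the learned projection the library applies to the encoder outputs is omitted.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def softmax(scores):\n",
        "    # subtract the max for numerical stability\n",
        "    peak = max(scores)\n",
        "    exps = [math.exp(s - peak) for s in scores]\n",
        "    total = sum(exps)\n",
        "    return [e / total for e in exps]\n",
        "\n",
        "def dot_attention(query, keys, values):\n",
        "    # one score per encoder time step: dot product of query and key\n",
        "    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]\n",
        "    # normalize scores into attention weights that sum to 1\n",
        "    weights = softmax(scores)\n",
        "    # context vector: weighted sum of the value vectors\n",
        "    context = [sum(w * v[d] for w, v in zip(weights, values))\n",
        "               for d in range(len(values[0]))]\n",
        "    return weights, context\n",
        "\n",
        "# hypothetical toy inputs: two encoder steps, hidden size 2\n",
        "encoder_outputs = [[1.0, 0.0], [0.0, 1.0]]\n",
        "decoder_state = [1.0, 0.0]\n",
        "weights, context = dot_attention(decoder_state, encoder_outputs, encoder_outputs)\n",
        "print(weights, context)\n",
        "```\n",
        "\n",
        "The encoder step that is most similar to the decoder state receives the largest weight, so it contributes most to the context vector."
      ]
    },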
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NPwcfddTa0oB",
        "colab_type": "text"
      },
      "source": [
        "## Initializing Training functions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x1BEqVyra2jW",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def loss_function(y_pred, y):\n",
        "   \n",
        "    #shape of y [batch_size, ty]\n",
        "    #shape of y_pred [batch_size, Ty, output_vocab_size] \n",
        "    sparsecategoricalcrossentropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,\n",
        "                                                                                  reduction='none')\n",
        "    loss = sparsecategoricalcrossentropy(y_true=y, y_pred=y_pred)\n",
        "    mask = tf.logical_not(tf.math.equal(y,0))   #output 0 for y=0 else output 1\n",
        "    mask = tf.cast(mask, dtype=loss.dtype)\n",
        "    loss = mask* loss\n",
        "    loss = tf.reduce_mean(loss)\n",
        "    return loss\n",
        "\n",
        "\n",
        "def train_step(input_batch, output_batch,encoder_initial_cell_state):\n",
        "    #initialize loss = 0\n",
        "    loss = 0\n",
        "    with tf.GradientTape() as tape:\n",
        "        encoder_emb_inp = encoderNetwork.encoder_embedding(input_batch)\n",
        "        a, a_tx, c_tx = encoderNetwork.encoder_rnnlayer(encoder_emb_inp, \n",
        "                                                        initial_state =encoder_initial_cell_state)\n",
        "\n",
        "        #[last step activations,last memory_state] of encoder passed as input to decoder Network\n",
        "        \n",
        "         \n",
        "        # Prepare correct Decoder input & output sequence data\n",
        "        decoder_input = output_batch[:,:-1] # ignore <end>\n",
        "        #compare logits with timestepped +1 version of decoder_input\n",
        "        decoder_output = output_batch[:,1:] #ignore <start>\n",
        "\n",
        "\n",
        "        # Decoder Embeddings\n",
        "        decoder_emb_inp = decoderNetwork.decoder_embedding(decoder_input)\n",
        "\n",
        "        #Setting up decoder memory from encoder output and Zero State for AttentionWrapperState\n",
        "        decoderNetwork.attention_mechanism.setup_memory(a)\n",
        "        decoder_initial_state = decoderNetwork.build_decoder_initial_state(BATCH_SIZE,\n",
        "                                                                           encoder_state=[a_tx, c_tx],\n",
        "                                                                           Dtype=tf.float32)\n",
        "        \n",
        "        #BasicDecoderOutput        \n",
        "        outputs, _, _ = decoderNetwork.decoder(decoder_emb_inp,initial_state=decoder_initial_state,\n",
        "                                               sequence_length=BATCH_SIZE*[Ty-1])\n",
        "\n",
        "        logits = outputs.rnn_output\n",
        "        #Calculate loss\n",
        "\n",
        "        loss = loss_function(logits, decoder_output)\n",
        "\n",
        "    #Returns the list of all layer variables / weights.\n",
        "    variables = encoderNetwork.trainable_variables + decoderNetwork.trainable_variables  \n",
        "    # differentiate loss wrt variables\n",
        "    gradients = tape.gradient(loss, variables)\n",
        "\n",
        "    #grads_and_vars – List of(gradient, variable) pairs.\n",
        "    grads_and_vars = zip(gradients,variables)\n",
        "    optimizer.apply_gradients(grads_and_vars)\n",
        "    return loss"
      ],
      "execution_count": 0,
      "outputs": []
    },
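    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Two ideas in the training step above can be sketched on plain Python lists: shifting the target sequence for teacher forcing, and masking the loss at padded positions. The helper names, token ids, and per-token loss values below are made up for illustration; the notebook itself performs these operations on TensorFlow tensors.\n",
        "\n",
        "```python\n",
        "PAD_ID = 0  # index reserved for padding\n",
        "\n",
        "def shift_for_teacher_forcing(target):\n",
        "    # the decoder reads every token but the last\n",
        "    # and must predict every token but the first\n",
        "    return target[:-1], target[1:]\n",
        "\n",
        "def masked_mean(per_token_loss, labels, pad_id=PAD_ID):\n",
        "    # zero out the loss wherever the label is padding\n",
        "    masked = [loss if label != pad_id else 0.0\n",
        "              for loss, label in zip(per_token_loss, labels)]\n",
        "    # like tf.reduce_mean in the notebook, divide by all positions\n",
        "    return sum(masked) / len(masked)\n",
        "\n",
        "# hypothetical ids: 1 = <start>, 2 = <end>, 0 = padding\n",
        "target = [1, 5, 7, 2, 0, 0]\n",
        "decoder_input, decoder_output = shift_for_teacher_forcing(target)\n",
        "# made-up per-token losses, one per decoder output position\n",
        "losses = [1.0, 2.0, 3.0, 4.0, 5.0]\n",
        "print(decoder_input, decoder_output, masked_mean(losses, decoder_output))\n",
        "```\n",
        "\n",
        "Only the three non-padding positions contribute to the sum, while the mean is still taken over all five positions. Dividing by the number of non-padding tokens instead would be an alternative design, but it is not what the notebook does."
      ]
    },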
    {
      "cell_type": "code",
      "metadata": {
        "id": "71Lkdx6GFb3A",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#RNN LSTM hidden and memory state initializer\n",
        "def initialize_initial_state():\n",
        "        return [tf.zeros((BATCH_SIZE, rnn_units)), tf.zeros((BATCH_SIZE, rnn_units))]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v5uzLcu2bNX3",
        "colab_type": "text"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PvfD2SknWrt6",
        "colab_type": "code",
        "outputId": "84f672cf-ec99-4f96-8654-47ee8b9ee1f5",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "source": [
        "epochs = 15\n",
        "for i in range(1, epochs+1):\n",
        "\n",
        "    encoder_initial_cell_state = initialize_initial_state()\n",
        "    total_loss = 0.0\n",
        "\n",
        "    for ( batch , (input_batch, output_batch)) in enumerate(dataset.take(steps_per_epoch)):\n",
        "        batch_loss = train_step(input_batch, output_batch, encoder_initial_cell_state)\n",
        "        total_loss += batch_loss\n",
        "        if (batch+1)%5 == 0:\n",
        "            print(\"total loss: {} epoch {} batch {} \".format(batch_loss.numpy(), i, batch+1))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "total loss: 3.9475886821746826 epoch 1 batch 5 \n",
            "total loss: 2.912675380706787 epoch 1 batch 10 \n",
            "total loss: 2.2815799713134766 epoch 1 batch 15 \n",
            "total loss: 2.183105230331421 epoch 1 batch 20 \n",
            "total loss: 2.2029659748077393 epoch 1 batch 25 \n",
            "total loss: 2.1441750526428223 epoch 1 batch 30 \n",
            "total loss: 2.0407521724700928 epoch 1 batch 35 \n",
            "total loss: 2.010983943939209 epoch 1 batch 40 \n",
            "total loss: 2.073960542678833 epoch 1 batch 45 \n",
            "total loss: 1.990903615951538 epoch 1 batch 50 \n",
            "total loss: 2.074843406677246 epoch 1 batch 55 \n",
            "total loss: 2.011075496673584 epoch 1 batch 60 \n",
            "total loss: 1.9887497425079346 epoch 1 batch 65 \n",
            "total loss: 1.892216682434082 epoch 1 batch 70 \n",
            "total loss: 1.9781100749969482 epoch 1 batch 75 \n",
            "total loss: 1.858034372329712 epoch 1 batch 80 \n",
            "total loss: 1.7822014093399048 epoch 1 batch 85 \n",
            "total loss: 1.8063167333602905 epoch 1 batch 90 \n",
            "total loss: 1.7766847610473633 epoch 1 batch 95 \n",
            "total loss: 1.8258825540542603 epoch 1 batch 100 \n",
            "total loss: 1.8362295627593994 epoch 1 batch 105 \n",
            "total loss: 1.7136967182159424 epoch 1 batch 110 \n",
            "total loss: 1.9544591903686523 epoch 1 batch 115 \n",
            "total loss: 1.8494930267333984 epoch 1 batch 120 \n",
            "total loss: 1.6849143505096436 epoch 1 batch 125 \n",
            "total loss: 1.6526412963867188 epoch 2 batch 5 \n",
            "total loss: 1.674989938735962 epoch 2 batch 10 \n",
            "total loss: 1.5851391553878784 epoch 2 batch 15 \n",
            "total loss: 1.6815704107284546 epoch 2 batch 20 \n",
            "total loss: 1.6417633295059204 epoch 2 batch 25 \n",
            "total loss: 1.697832703590393 epoch 2 batch 30 \n",
            "total loss: 1.7252113819122314 epoch 2 batch 35 \n",
            "total loss: 1.537140965461731 epoch 2 batch 40 \n",
            "total loss: 1.6764580011367798 epoch 2 batch 45 \n",
            "total loss: 1.453371286392212 epoch 2 batch 50 \n",
            "total loss: 1.6771221160888672 epoch 2 batch 55 \n",
            "total loss: 1.5605512857437134 epoch 2 batch 60 \n",
            "total loss: 1.6059997081756592 epoch 2 batch 65 \n",
            "total loss: 1.467176079750061 epoch 2 batch 70 \n",
            "total loss: 1.609415054321289 epoch 2 batch 75 \n",
            "total loss: 1.5329309701919556 epoch 2 batch 80 \n",
            "total loss: 1.6187160015106201 epoch 2 batch 85 \n",
            "total loss: 1.5867189168930054 epoch 2 batch 90 \n",
            "total loss: 1.5069472789764404 epoch 2 batch 95 \n",
            "total loss: 1.64217209815979 epoch 2 batch 100 \n",
            "total loss: 1.5193077325820923 epoch 2 batch 105 \n",
            "total loss: 1.5160114765167236 epoch 2 batch 110 \n",
            "total loss: 1.516736626625061 epoch 2 batch 115 \n",
            "total loss: 1.4899837970733643 epoch 2 batch 120 \n",
            "total loss: 1.6041948795318604 epoch 2 batch 125 \n",
            "total loss: 1.502350926399231 epoch 3 batch 5 \n",
            "total loss: 1.360275149345398 epoch 3 batch 10 \n",
            "total loss: 1.3449124097824097 epoch 3 batch 15 \n",
            "total loss: 1.378374457359314 epoch 3 batch 20 \n",
            "total loss: 1.4500166177749634 epoch 3 batch 25 \n",
            "total loss: 1.3589526414871216 epoch 3 batch 30 \n",
            "total loss: 1.3434584140777588 epoch 3 batch 35 \n",
            "total loss: 1.2667752504348755 epoch 3 batch 40 \n",
            "total loss: 1.4497849941253662 epoch 3 batch 45 \n",
            "total loss: 1.5071169137954712 epoch 3 batch 50 \n",
            "total loss: 1.344785451889038 epoch 3 batch 55 \n",
            "total loss: 1.4199110269546509 epoch 3 batch 60 \n",
            "total loss: 1.3664555549621582 epoch 3 batch 65 \n",
            "total loss: 1.3798571825027466 epoch 3 batch 70 \n",
            "total loss: 1.4127501249313354 epoch 3 batch 75 \n",
            "total loss: 1.3040590286254883 epoch 3 batch 80 \n",
            "total loss: 1.4017330408096313 epoch 3 batch 85 \n",
            "total loss: 1.4011389017105103 epoch 3 batch 90 \n",
            "total loss: 1.295208215713501 epoch 3 batch 95 \n",
            "total loss: 1.3480514287948608 epoch 3 batch 100 \n",
            "total loss: 1.2870609760284424 epoch 3 batch 105 \n",
            "total loss: 1.4333269596099854 epoch 3 batch 110 \n",
            "total loss: 1.3486473560333252 epoch 3 batch 115 \n",
            "total loss: 1.26927649974823 epoch 3 batch 120 \n",
            "total loss: 1.2078845500946045 epoch 3 batch 125 \n",
            "total loss: 1.1198420524597168 epoch 4 batch 5 \n",
            "total loss: 1.0763131380081177 epoch 4 batch 10 \n",
            "total loss: 1.1939853429794312 epoch 4 batch 15 \n",
            "total loss: 1.2020100355148315 epoch 4 batch 20 \n",
            "total loss: 1.0719928741455078 epoch 4 batch 25 \n",
            "total loss: 1.124625325202942 epoch 4 batch 30 \n",
            "total loss: 1.1307220458984375 epoch 4 batch 35 \n",
            "total loss: 1.1385688781738281 epoch 4 batch 40 \n",
            "total loss: 1.1286941766738892 epoch 4 batch 45 \n",
            "total loss: 1.1118738651275635 epoch 4 batch 50 \n",
            "total loss: 1.0924361944198608 epoch 4 batch 55 \n",
            "total loss: 1.2378876209259033 epoch 4 batch 60 \n",
            "total loss: 1.1472713947296143 epoch 4 batch 65 \n",
            "total loss: 1.1867095232009888 epoch 4 batch 70 \n",
            "total loss: 1.1062105894088745 epoch 4 batch 75 \n",
            "total loss: 1.0883691310882568 epoch 4 batch 80 \n",
            "total loss: 1.1805391311645508 epoch 4 batch 85 \n",
            "total loss: 1.3100593090057373 epoch 4 batch 90 \n",
            "total loss: 1.199307918548584 epoch 4 batch 95 \n",
            "total loss: 1.1042678356170654 epoch 4 batch 100 \n",
            "total loss: 1.0394186973571777 epoch 4 batch 105 \n",
            "total loss: 1.1532882452011108 epoch 4 batch 110 \n",
            "total loss: 1.0915677547454834 epoch 4 batch 115 \n",
            "total loss: 1.0750417709350586 epoch 4 batch 120 \n",
            "total loss: 1.0421984195709229 epoch 4 batch 125 \n",
            "total loss: 1.0400830507278442 epoch 5 batch 5 \n",
            "total loss: 1.0274797677993774 epoch 5 batch 10 \n",
            "total loss: 0.8721328973770142 epoch 5 batch 15 \n",
            "total loss: 0.9050451517105103 epoch 5 batch 20 \n",
            "total loss: 0.9007365107536316 epoch 5 batch 25 \n",
            "total loss: 0.8890656232833862 epoch 5 batch 30 \n",
            "total loss: 0.8846011161804199 epoch 5 batch 35 \n",
            "total loss: 0.8554547429084778 epoch 5 batch 40 \n",
            "total loss: 1.1025922298431396 epoch 5 batch 45 \n",
            "total loss: 0.9758970141410828 epoch 5 batch 50 \n",
            "total loss: 1.0573564767837524 epoch 5 batch 55 \n",
            "total loss: 0.9744541049003601 epoch 5 batch 60 \n",
            "total loss: 0.9071753621101379 epoch 5 batch 65 \n",
            "total loss: 0.970922589302063 epoch 5 batch 70 \n",
            "total loss: 0.9922286868095398 epoch 5 batch 75 \n",
            "total loss: 0.8951885104179382 epoch 5 batch 80 \n",
            "total loss: 1.0515273809432983 epoch 5 batch 85 \n",
            "total loss: 0.9692702293395996 epoch 5 batch 90 \n",
            "total loss: 0.8851386904716492 epoch 5 batch 95 \n",
            "total loss: 1.0359522104263306 epoch 5 batch 100 \n",
            "total loss: 0.9581290483474731 epoch 5 batch 105 \n",
            "total loss: 0.9426918029785156 epoch 5 batch 110 \n",
            "total loss: 0.9563409686088562 epoch 5 batch 115 \n",
            "total loss: 0.9106627702713013 epoch 5 batch 120 \n",
            "total loss: 0.9571183919906616 epoch 5 batch 125 \n",
            "total loss: 0.6938820481300354 epoch 6 batch 5 \n",
            "total loss: 0.760671854019165 epoch 6 batch 10 \n",
            "total loss: 0.699514627456665 epoch 6 batch 15 \n",
            "total loss: 0.6691784858703613 epoch 6 batch 20 \n",
            "total loss: 0.8158406019210815 epoch 6 batch 25 \n",
            "total loss: 0.7383745908737183 epoch 6 batch 30 \n",
            "total loss: 0.7447091341018677 epoch 6 batch 35 \n",
            "total loss: 0.7862703800201416 epoch 6 batch 40 \n",
            "total loss: 0.8322451710700989 epoch 6 batch 45 \n",
            "total loss: 0.8715308904647827 epoch 6 batch 50 \n",
            "total loss: 0.706369161605835 epoch 6 batch 55 \n",
            "total loss: 0.7995638251304626 epoch 6 batch 60 \n",
            "total loss: 0.8098785281181335 epoch 6 batch 65 \n",
            "total loss: 0.6516961455345154 epoch 6 batch 70 \n",
            "total loss: 0.7424792647361755 epoch 6 batch 75 \n",
            "total loss: 0.7396417856216431 epoch 6 batch 80 \n",
            "total loss: 0.7362191677093506 epoch 6 batch 85 \n",
            "total loss: 0.9558976292610168 epoch 6 batch 90 \n",
            "total loss: 0.8189946413040161 epoch 6 batch 95 \n",
            "total loss: 0.7554519176483154 epoch 6 batch 100 \n",
            "total loss: 0.772563099861145 epoch 6 batch 105 \n",
            "total loss: 0.8337545394897461 epoch 6 batch 110 \n",
            "total loss: 0.7600473761558533 epoch 6 batch 115 \n",
            "total loss: 0.7708126902580261 epoch 6 batch 120 \n",
            "total loss: 0.6998305320739746 epoch 6 batch 125 \n",
            "total loss: 0.5774018168449402 epoch 7 batch 5 \n",
            "total loss: 0.6558392643928528 epoch 7 batch 10 \n",
            "total loss: 0.5383725762367249 epoch 7 batch 15 \n",
            "total loss: 0.6437508463859558 epoch 7 batch 20 \n",
            "total loss: 0.594805121421814 epoch 7 batch 25 \n",
            "total loss: 0.5795590281486511 epoch 7 batch 30 \n",
            "total loss: 0.7567883729934692 epoch 7 batch 35 \n",
            "total loss: 0.5663882493972778 epoch 7 batch 40 \n",
            "total loss: 0.6014893054962158 epoch 7 batch 45 \n",
            "total loss: 0.5960389971733093 epoch 7 batch 50 \n",
            "total loss: 0.633935809135437 epoch 7 batch 55 \n",
            "total loss: 0.6122901439666748 epoch 7 batch 60 \n",
            "total loss: 0.6862307786941528 epoch 7 batch 65 \n",
            "total loss: 0.7035883665084839 epoch 7 batch 70 \n",
            "total loss: 0.7250910997390747 epoch 7 batch 75 \n",
            "total loss: 0.6099406480789185 epoch 7 batch 80 \n",
            "total loss: 0.5820813179016113 epoch 7 batch 85 \n",
            "total loss: 0.5596920251846313 epoch 7 batch 90 \n",
            "total loss: 0.6520225405693054 epoch 7 batch 95 \n",
            "total loss: 0.6929486393928528 epoch 7 batch 100 \n",
            "total loss: 0.6704176664352417 epoch 7 batch 105 \n",
            "total loss: 0.6779621839523315 epoch 7 batch 110 \n",
            "total loss: 0.6607205271720886 epoch 7 batch 115 \n",
            "total loss: 0.5835480690002441 epoch 7 batch 120 \n",
            "total loss: 0.6930114030838013 epoch 7 batch 125 \n",
            "total loss: 0.4390392303466797 epoch 8 batch 5 \n",
            "total loss: 0.502900242805481 epoch 8 batch 10 \n",
            "total loss: 0.44953665137290955 epoch 8 batch 15 \n",
            "total loss: 0.5513165593147278 epoch 8 batch 20 \n",
            "total loss: 0.5228049159049988 epoch 8 batch 25 \n",
            "total loss: 0.48368993401527405 epoch 8 batch 30 \n",
            "total loss: 0.4797203540802002 epoch 8 batch 35 \n",
            "total loss: 0.5822355151176453 epoch 8 batch 40 \n",
            "total loss: 0.4232334494590759 epoch 8 batch 45 \n",
            "total loss: 0.5870698094367981 epoch 8 batch 50 \n",
            "total loss: 0.48263269662857056 epoch 8 batch 55 \n",
            "total loss: 0.44463014602661133 epoch 8 batch 60 \n",
            "total loss: 0.49221715331077576 epoch 8 batch 65 \n",
            "total loss: 0.5247334837913513 epoch 8 batch 70 \n",
            "total loss: 0.6095311045646667 epoch 8 batch 75 \n",
            "total loss: 0.5857402086257935 epoch 8 batch 80 \n",
            "total loss: 0.46884751319885254 epoch 8 batch 85 \n",
            "total loss: 0.5228506326675415 epoch 8 batch 90 \n",
            "total loss: 0.46329981088638306 epoch 8 batch 95 \n",
            "total loss: 0.5708974003791809 epoch 8 batch 100 \n",
            "total loss: 0.5332533121109009 epoch 8 batch 105 \n",
            "total loss: 0.532862663269043 epoch 8 batch 110 \n",
            "total loss: 0.5066767334938049 epoch 8 batch 115 \n",
            "total loss: 0.5123209357261658 epoch 8 batch 120 \n",
            "total loss: 0.49092260003089905 epoch 8 batch 125 \n",
            "total loss: 0.3689466118812561 epoch 9 batch 5 \n",
            "total loss: 0.3250238299369812 epoch 9 batch 10 \n",
            "total loss: 0.44425806403160095 epoch 9 batch 15 \n",
            "total loss: 0.3422010838985443 epoch 9 batch 20 \n",
            "total loss: 0.3302001357078552 epoch 9 batch 25 \n",
            "total loss: 0.3376121520996094 epoch 9 batch 30 \n",
            "total loss: 0.43184441328048706 epoch 9 batch 35 \n",
            "total loss: 0.40924182534217834 epoch 9 batch 40 \n",
            "total loss: 0.38222527503967285 epoch 9 batch 45 \n",
            "total loss: 0.4478159546852112 epoch 9 batch 50 \n",
            "total loss: 0.4593771994113922 epoch 9 batch 55 \n",
            "total loss: 0.3862895369529724 epoch 9 batch 60 \n",
            "total loss: 0.40882641077041626 epoch 9 batch 65 \n",
            "total loss: 0.4312051236629486 epoch 9 batch 70 \n",
            "total loss: 0.41449132561683655 epoch 9 batch 75 \n",
            "total loss: 0.45340195298194885 epoch 9 batch 80 \n",
            "total loss: 0.4121376574039459 epoch 9 batch 85 \n",
            "total loss: 0.5007123947143555 epoch 9 batch 90 \n",
            "total loss: 0.4919680655002594 epoch 9 batch 95 \n",
            "total loss: 0.4845644533634186 epoch 9 batch 100 \n",
            "total loss: 0.5462281107902527 epoch 9 batch 105 \n",
            "total loss: 0.3803269863128662 epoch 9 batch 110 \n",
            "total loss: 0.4410593509674072 epoch 9 batch 115 \n",
            "total loss: 0.44156259298324585 epoch 9 batch 120 \n",
            "total loss: 0.48795849084854126 epoch 9 batch 125 \n",
            "total loss: 0.2677317261695862 epoch 10 batch 5 \n",
            "total loss: 0.2849716544151306 epoch 10 batch 10 \n",
            "total loss: 0.2650394141674042 epoch 10 batch 15 \n",
            "total loss: 0.3369045853614807 epoch 10 batch 20 \n",
            "total loss: 0.27701759338378906 epoch 10 batch 25 \n",
            "total loss: 0.2801435589790344 epoch 10 batch 30 \n",
            "total loss: 0.2140922248363495 epoch 10 batch 35 \n",
            "total loss: 0.3308884799480438 epoch 10 batch 40 \n",
            "total loss: 0.3573286831378937 epoch 10 batch 45 \n",
            "total loss: 0.3585323691368103 epoch 10 batch 50 \n",
            "total loss: 0.3311135172843933 epoch 10 batch 55 \n",
            "total loss: 0.4792550206184387 epoch 10 batch 60 \n",
            "total loss: 0.3666520416736603 epoch 10 batch 65 \n",
            "total loss: 0.3469542860984802 epoch 10 batch 70 \n",
            "total loss: 0.39160677790641785 epoch 10 batch 75 \n",
            "total loss: 0.375261127948761 epoch 10 batch 80 \n",
            "total loss: 0.34272241592407227 epoch 10 batch 85 \n",
            "total loss: 0.43078818917274475 epoch 10 batch 90 \n",
            "total loss: 0.2668665647506714 epoch 10 batch 95 \n",
            "total loss: 0.37373679876327515 epoch 10 batch 100 \n",
            "total loss: 0.3685140311717987 epoch 10 batch 105 \n",
            "total loss: 0.3151918351650238 epoch 10 batch 110 \n",
            "total loss: 0.34442174434661865 epoch 10 batch 115 \n",
            "total loss: 0.4334893822669983 epoch 10 batch 120 \n",
            "total loss: 0.3609371781349182 epoch 10 batch 125 \n",
            "total loss: 0.22984731197357178 epoch 11 batch 5 \n",
            "total loss: 0.2418794184923172 epoch 11 batch 10 \n",
            "total loss: 0.2832423150539398 epoch 11 batch 15 \n",
            "total loss: 0.24624614417552948 epoch 11 batch 20 \n",
            "total loss: 0.2362026870250702 epoch 11 batch 25 \n",
            "total loss: 0.24753396213054657 epoch 11 batch 30 \n",
            "total loss: 0.30054670572280884 epoch 11 batch 35 \n",
            "total loss: 0.2899046242237091 epoch 11 batch 40 \n",
            "total loss: 0.2667545676231384 epoch 11 batch 45 \n",
            "total loss: 0.2646659314632416 epoch 11 batch 50 \n",
            "total loss: 0.2119644582271576 epoch 11 batch 55 \n",
            "total loss: 0.2534087002277374 epoch 11 batch 60 \n",
            "total loss: 0.2593185305595398 epoch 11 batch 65 \n",
            "total loss: 0.3010985553264618 epoch 11 batch 70 \n",
            "total loss: 0.30156993865966797 epoch 11 batch 75 \n",
            "total loss: 0.31427207589149475 epoch 11 batch 80 \n",
            "total loss: 0.30148079991340637 epoch 11 batch 85 \n",
            "total loss: 0.31685349345207214 epoch 11 batch 90 \n",
            "total loss: 0.2858041822910309 epoch 11 batch 95 \n",
            "total loss: 0.23358872532844543 epoch 11 batch 100 \n",
            "total loss: 0.3077571392059326 epoch 11 batch 105 \n",
            "total loss: 0.2969404458999634 epoch 11 batch 110 \n",
            "total loss: 0.36080026626586914 epoch 11 batch 115 \n",
            "total loss: 0.2823699116706848 epoch 11 batch 120 \n",
            "total loss: 0.28497424721717834 epoch 11 batch 125 \n",
            "total loss: 0.2739666998386383 epoch 12 batch 5 \n",
            "total loss: 0.247772216796875 epoch 12 batch 10 \n",
            "total loss: 0.21159425377845764 epoch 12 batch 15 \n",
            "total loss: 0.24877581000328064 epoch 12 batch 20 \n",
            "total loss: 0.23003773391246796 epoch 12 batch 25 \n",
            "total loss: 0.22304224967956543 epoch 12 batch 30 \n",
            "total loss: 0.2357630431652069 epoch 12 batch 35 \n",
            "total loss: 0.24652035534381866 epoch 12 batch 40 \n",
            "total loss: 0.24459195137023926 epoch 12 batch 45 \n",
            "total loss: 0.2198447287082672 epoch 12 batch 50 \n",
            "total loss: 0.22670245170593262 epoch 12 batch 55 \n",
            "total loss: 0.2391890287399292 epoch 12 batch 60 \n",
            "total loss: 0.2453366070985794 epoch 12 batch 65 \n",
            "total loss: 0.21846142411231995 epoch 12 batch 70 \n",
            "total loss: 0.25742220878601074 epoch 12 batch 75 \n",
            "total loss: 0.2598118185997009 epoch 12 batch 80 \n",
            "total loss: 0.2885677218437195 epoch 12 batch 85 \n",
            "total loss: 0.32734522223472595 epoch 12 batch 90 \n",
            "total loss: 0.3083980083465576 epoch 12 batch 95 \n",
            "total loss: 0.3234527111053467 epoch 12 batch 100 \n",
            "total loss: 0.29528990387916565 epoch 12 batch 105 \n",
            "total loss: 0.27330103516578674 epoch 12 batch 110 \n",
            "total loss: 0.2824668288230896 epoch 12 batch 115 \n",
            "total loss: 0.26833924651145935 epoch 12 batch 120 \n",
            "total loss: 0.3090164065361023 epoch 12 batch 125 \n",
            "total loss: 0.18143436312675476 epoch 13 batch 5 \n",
            "total loss: 0.24107468128204346 epoch 13 batch 10 \n",
            "total loss: 0.1723310351371765 epoch 13 batch 15 \n",
            "total loss: 0.2374371737241745 epoch 13 batch 20 \n",
            "total loss: 0.18838974833488464 epoch 13 batch 25 \n",
            "total loss: 0.1868618130683899 epoch 13 batch 30 \n",
            "total loss: 0.2468196451663971 epoch 13 batch 35 \n",
            "total loss: 0.18816381692886353 epoch 13 batch 40 \n",
            "total loss: 0.2015218436717987 epoch 13 batch 45 \n",
            "total loss: 0.17972926795482635 epoch 13 batch 50 \n",
            "total loss: 0.19488045573234558 epoch 13 batch 55 \n",
            "total loss: 0.179433211684227 epoch 13 batch 60 \n",
            "total loss: 0.18720710277557373 epoch 13 batch 65 \n",
            "total loss: 0.26200735569000244 epoch 13 batch 70 \n",
            "total loss: 0.2021588832139969 epoch 13 batch 75 \n",
            "total loss: 0.2547597587108612 epoch 13 batch 80 \n",
            "total loss: 0.2753807604312897 epoch 13 batch 85 \n",
            "total loss: 0.27378445863723755 epoch 13 batch 90 \n",
            "total loss: 0.24202470481395721 epoch 13 batch 95 \n",
            "total loss: 0.22158583998680115 epoch 13 batch 100 \n",
            "total loss: 0.22244706749916077 epoch 13 batch 105 \n",
            "total loss: 0.23681640625 epoch 13 batch 110 \n",
            "total loss: 0.2990795373916626 epoch 13 batch 115 \n",
            "total loss: 0.2641446888446808 epoch 13 batch 120 \n",
            "total loss: 0.23204472661018372 epoch 13 batch 125 \n",
            "total loss: 0.1728627234697342 epoch 14 batch 5 \n",
            "total loss: 0.17328783869743347 epoch 14 batch 10 \n",
            "total loss: 0.20071764290332794 epoch 14 batch 15 \n",
            "total loss: 0.15985815227031708 epoch 14 batch 20 \n",
            "total loss: 0.16585272550582886 epoch 14 batch 25 \n",
            "total loss: 0.17702646553516388 epoch 14 batch 30 \n",
            "total loss: 0.19213533401489258 epoch 14 batch 35 \n",
            "total loss: 0.1678582727909088 epoch 14 batch 40 \n",
            "total loss: 0.17316825687885284 epoch 14 batch 45 \n",
            "total loss: 0.18272583186626434 epoch 14 batch 50 \n",
            "total loss: 0.2643834352493286 epoch 14 batch 55 \n",
            "total loss: 0.1914786398410797 epoch 14 batch 60 \n",
            "total loss: 0.24789147078990936 epoch 14 batch 65 \n",
            "total loss: 0.20449848473072052 epoch 14 batch 70 \n",
            "total loss: 0.20783527195453644 epoch 14 batch 75 \n",
            "total loss: 0.20063571631908417 epoch 14 batch 80 \n",
            "total loss: 0.22110366821289062 epoch 14 batch 85 \n",
            "total loss: 0.27967819571495056 epoch 14 batch 90 \n",
            "total loss: 0.21627402305603027 epoch 14 batch 95 \n",
            "total loss: 0.2716841697692871 epoch 14 batch 100 \n",
            "total loss: 0.26125216484069824 epoch 14 batch 105 \n",
            "total loss: 0.28036823868751526 epoch 14 batch 110 \n",
            "total loss: 0.2875978350639343 epoch 14 batch 115 \n",
            "total loss: 0.24142596125602722 epoch 14 batch 120 \n",
            "total loss: 0.2443583458662033 epoch 14 batch 125 \n",
            "total loss: 0.12856344878673553 epoch 15 batch 5 \n",
            "total loss: 0.19890321791172028 epoch 15 batch 10 \n",
            "total loss: 0.21472203731536865 epoch 15 batch 15 \n",
            "total loss: 0.1831301748752594 epoch 15 batch 20 \n",
            "total loss: 0.18663254380226135 epoch 15 batch 25 \n",
            "total loss: 0.1720810979604721 epoch 15 batch 30 \n",
            "total loss: 0.1868405044078827 epoch 15 batch 35 \n",
            "total loss: 0.22495588660240173 epoch 15 batch 40 \n",
            "total loss: 0.20775218307971954 epoch 15 batch 45 \n",
            "total loss: 0.1730271875858307 epoch 15 batch 50 \n",
            "total loss: 0.2216344177722931 epoch 15 batch 55 \n",
            "total loss: 0.21534466743469238 epoch 15 batch 60 \n",
            "total loss: 0.16373391449451447 epoch 15 batch 65 \n",
            "total loss: 0.20450153946876526 epoch 15 batch 70 \n",
            "total loss: 0.22405973076820374 epoch 15 batch 75 \n",
            "total loss: 0.21495337784290314 epoch 15 batch 80 \n",
            "total loss: 0.19726605713367462 epoch 15 batch 85 \n",
            "total loss: 0.1876198798418045 epoch 15 batch 90 \n",
            "total loss: 0.19518446922302246 epoch 15 batch 95 \n",
            "total loss: 0.2388114184141159 epoch 15 batch 100 \n",
            "total loss: 0.19776995480060577 epoch 15 batch 105 \n",
            "total loss: 0.20737934112548828 epoch 15 batch 110 \n",
            "total loss: 0.20728227496147156 epoch 15 batch 115 \n",
            "total loss: 0.19053156673908234 epoch 15 batch 120 \n",
            "total loss: 0.17176872491836548 epoch 15 batch 125 \n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nDyK-EGqbN5r",
        "colab_type": "text"
      },
      "source": [
        "## Evaluation"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "y98sfom7SuGy",
        "colab_type": "code",
        "outputId": "c0f424cd-4e8e-4bf1-92bd-c7d8cebb2994",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 315
        }
      },
      "source": [
        "#In this section we evaluate our model on a raw_input converted to german, for this the entire sentence has to be passed\n",
        "#through the length of the model, for this we use greedsampler to run through the decoder\n",
        "#and the final embedding matrix trained on the data is used to generate embeddings\n",
        "input_raw='how are you'\n",
        "\n",
        "# We have a transcript file containing English-German pairs\n",
        "# Preprocess X\n",
        "input_lines = ['<start> '+input_raw+'']\n",
        "input_sequences = [[en_tokenizer.word_index[w] for w in line.split(' ')] for line in input_lines]\n",
        "input_sequences = tf.keras.preprocessing.sequence.pad_sequences(input_sequences,\n",
        "                                                                maxlen=Tx, padding='post')\n",
        "inp = tf.convert_to_tensor(input_sequences)\n",
        "#print(inp.shape)\n",
        "inference_batch_size = input_sequences.shape[0]\n",
        "encoder_initial_cell_state = [tf.zeros((inference_batch_size, rnn_units)),\n",
        "                              tf.zeros((inference_batch_size, rnn_units))]\n",
        "encoder_emb_inp = encoderNetwork.encoder_embedding(inp)\n",
        "a, a_tx, c_tx = encoderNetwork.encoder_rnnlayer(encoder_emb_inp,\n",
        "                                                initial_state =encoder_initial_cell_state)\n",
        "print('a_tx :',a_tx.shape)\n",
        "print('c_tx :', c_tx.shape)\n",
        "\n",
        "start_tokens = tf.fill([inference_batch_size],ge_tokenizer.word_index['<start>'])\n",
        "\n",
        "end_token = ge_tokenizer.word_index['<end>']\n",
        "\n",
        "greedy_sampler = tfa.seq2seq.GreedyEmbeddingSampler()\n",
        "\n",
        "decoder_input = tf.expand_dims([ge_tokenizer.word_index['<start>']]* inference_batch_size,1)\n",
        "decoder_emb_inp = decoderNetwork.decoder_embedding(decoder_input)\n",
        "\n",
        "decoder_instance = tfa.seq2seq.BasicDecoder(cell = decoderNetwork.rnn_cell, sampler = greedy_sampler,\n",
        "                                            output_layer=decoderNetwork.dense_layer)\n",
        "decoderNetwork.attention_mechanism.setup_memory(a)\n",
        "#pass [ last step activations , encoder memory_state ] as input to decoder for LSTM\n",
        "print(\"decoder_initial_state = [a_tx, c_tx] :\",np.array([a_tx, c_tx]).shape)\n",
        "decoder_initial_state = decoderNetwork.build_decoder_initial_state(inference_batch_size,\n",
        "                                                                   encoder_state=[a_tx, c_tx],\n",
        "                                                                   Dtype=tf.float32)\n",
        "print(\"\\nCompared to simple encoder-decoder without attention, the decoder_initial_state \\\n",
        " is an AttentionWrapperState object containing s_prev tensors and context and alignment vector \\n \")\n",
        "print(\"decoder initial state shape :\",np.array(decoder_initial_state).shape)\n",
        "print(\"decoder_initial_state tensor \\n\", decoder_initial_state)\n",
        "\n",
        "# Since we do not know the target sequence lengths in advance, we use maximum_iterations to limit the translation lengths.\n",
        "# One heuristic is to decode up to two times the source sentence lengths.\n",
        "maximum_iterations = tf.round(tf.reduce_max(Tx) * 2)\n",
        "\n",
        "#initialize inference decoder\n",
        "decoder_embedding_matrix = decoderNetwork.decoder_embedding.variables[0] \n",
        "(first_finished, first_inputs,first_state) = decoder_instance.initialize(decoder_embedding_matrix,\n",
        "                             start_tokens = start_tokens,\n",
        "                             end_token=end_token,\n",
        "                             initial_state = decoder_initial_state)\n",
        "#print( first_finished.shape)\n",
        "print(\"\\nfirst_inputs returns the same decoder_input i.e. embedding of  <start> :\",first_inputs.shape)\n",
        "print(\"start_index_emb_avg \", tf.reduce_sum(tf.reduce_mean(first_inputs, axis=0))) # mean along the batch\n",
        "\n",
        "inputs = first_inputs\n",
        "state = first_state  \n",
        "predictions = np.empty((inference_batch_size,0), dtype = np.int32)                                                                             \n",
        "for j in range(maximum_iterations):\n",
        "    outputs, next_state, next_inputs, finished = decoder_instance.step(j,inputs,state)\n",
        "    inputs = next_inputs\n",
        "    state = next_state\n",
        "    outputs = np.expand_dims(outputs.sample_id,axis = -1)\n",
        "    predictions = np.append(predictions, outputs, axis = -1)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "a_tx : (1, 1024)\n",
            "c_tx : (1, 1024)\n",
            "decoder_initial_state = [a_tx, c_tx] : (2, 1, 1024)\n",
            "\n",
            "Compared to simple encoder-decoder without attention, the decoder_initial_state  is an AttentionWrapperState object containing s_prev tensors and context and alignment vector \n",
            " \n",
            "decoder initial state shape : (6,)\n",
            "decoder_initial_state tensor \n",
            " AttentionWrapperState(cell_state=[<tf.Tensor: shape=(1, 1024), dtype=float32, numpy=\n",
            "array([[ 0.00032512,  0.00170071,  0.06353987, ..., -0.01063914,\n",
            "        -0.01768327, -0.08082021]], dtype=float32)>, <tf.Tensor: shape=(1, 1024), dtype=float32, numpy=\n",
            "array([[ 0.00110459,  0.00727613,  0.17626816, ..., -0.02890627,\n",
            "        -0.06944194, -0.15541168]], dtype=float32)>], attention=<tf.Tensor: shape=(1, 1024), dtype=float32, numpy=array([[0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>, time=<tf.Tensor: shape=(), dtype=int32, numpy=0>, alignments=<tf.Tensor: shape=(1, 7), dtype=float32, numpy=array([[0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>, alignment_history=(), attention_state=<tf.Tensor: shape=(1, 7), dtype=float32, numpy=array([[0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>)\n",
            "\n",
            "first_inputs returns the same decoder_input i.e. embedding of  <start> : (1, 256)\n",
            "start_index_emb_avg  tf.Tensor(0.108049706, shape=(), dtype=float32)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iodjSItQds1t",
        "colab_type": "text"
      },
      "source": [
        "## Final Translation"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "K6aWFB5IWlH2",
        "colab_type": "code",
        "outputId": "a2374a75-d3d7-444e-f3b1-32cfde0c57a9",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 104
        }
      },
      "source": [
        "#prediction based on our sentence earlier\n",
        "print(\"English Sentence:\")\n",
        "print(input_raw)\n",
        "print(\"\\nGerman Translation:\")\n",
        "for i in range(len(predictions)):\n",
        "    line = predictions[i,:]\n",
        "    seq = list(itertools.takewhile( lambda index: index !=2, line))\n",
        "    print(\" \".join( [ge_tokenizer.index_word[w] for w in seq]))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "English Sentence:\n",
            "how are you\n",
            "\n",
            "German Translation:\n",
            "wie arrogant\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g6Av-oPWvRc4",
        "colab_type": "text"
      },
      "source": [
        "### The accuracy can be improved by implementing:\n",
        "* Beam Search or Lexicon Search\n",
        "* Bi-directional encoder-decoder model "
      ]
    }
  ]
}
